{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"background-color: #04D7FD; padding: 20px; text-align: left;\">\n",
    "    <h1 style=\"color: #000000; font-size: 30px; margin: 0;\">data-prep-kit planning agent</h1>   \n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -qq -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import HTML\n",
    "task = \"Process the provided PDF dataset to identify and extract only documents that don't contain inappropriate language. Remove the duplications.\"\n",
    "HTML(f\"<p><span style='color:blue; font-weight:bold; font-size:14.0pt;'>TASK: {task}</span></p>\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import os\n",
    "\n",
    "from llm_utils.logging import prep_loggers\n",
    "os.environ[\"LLM_LOG_PATH\"] = \"./logs/llm_log.txt\"\n",
    "prep_loggers(\"llm=INFO\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The tools in DPK agents are the transforms.\n",
    "# Each tool is described as json dictionary with its name, description, input parameters, and how to import it.\n",
    "# The list of the tools exists in llm_utils/tools.py file.\n",
    "from llm_utils.dpk.tools import *\n",
    "print(tools_json)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is an example of a plan for a simple task. It is possed to the prompt to enhance the planning results.\n",
    "from llm_utils.dpk.examples import *\n",
    "print(example_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is a string that contains several constraints on the order of the tools in the plan.\n",
    "# It is a free text and can be found in llm_utils/constraints.py file.\n",
    "from llm_utils.dpk.constraints import *\n",
    "print(constraints)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define LLM models\n",
    "\n",
    "We have have tested our project with the following LLM execution frameworks: [Watsonx](https://www.ibm.com/watsonx), [Replicate](https://replicate.com/), and locally running [Ollama](https://ollama.com/).\n",
    "To use one of the frameworks uncomment its part in the cell below while commenting out the other frameworks.\n",
    "Please note that the notebooks have been tested with specific Large Language Models (LLMs) that are mentioned in the cell, and due to the inherent nature of LLMs, using a different model may not produce the same results.\n",
    "\n",
    "- To use Replicate:\n",
    "  - Obtain Replicate API token\n",
    "  - Store the following value in the `.env` file located in your project directory:\n",
    "    ```\n",
    "        REPLICATE_API_TOKEN=<your Replicate API token>\n",
    "    ```\n",
    "- To use Ollama: \n",
    "  - Download [Ollama](https://ollama.com/download).\n",
    "  - Download one of the supported [models](https://ollama.com/search). We tested with `llama3.3` model.\n",
    "  - update the `model_ollama_*` names if needed.\n",
    "- To use Watsonx:\n",
    "  - Register for Watsonx\n",
    "  - Obtain its API key\n",
    "  - Store the following values in the `.env` file located in your project directory:\n",
    "    ```\n",
    "        WATSONX_URL=<WatsonX entry point, e.g. https://us-south.ml.cloud.ibm.com>\n",
    "        WATSON_PROJECT_ID=<your Watsonx project ID>\n",
    "        WATSONX_APIKEY=<your Watsonx API key>\n",
    "    ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llm_utils.models import getChatLLM\n",
    "from dotenv import dotenv_values\n",
    "\n",
    "# watsonx part \n",
    "# config = dotenv_values(\"./.env\")\n",
    "# model_watsonx_id1 = \"ibm-granite/granite-3.1-8b-instruct\"\n",
    "# model_watsonx_id2 = \"meta-llama/llama-3-1-70b-instruct\"\n",
    "# model_watsonx_id3 = \"meta-llama/llama-3-3-70b-instruct\"\n",
    "# model_watsonx_id4 = \"ibm/granite-34b-code-instruct\"\n",
    "\n",
    "# llm_plan = getChatLLM(\"watsonx\", model_watsonx_id2, config)\n",
    "# llm_judge = getChatLLM(\"watsonx\", model_watsonx_id2, config)\n",
    "# llm_generate = getChatLLM(\"watsonx\", model_watsonx_id2, config)\n",
    "\n",
    "# # ollama part\n",
    "# model_ollama = \\\"llama3.3\\\"\\n\",\n",
    "# llm_plan = getChatLLM(\\\"ollama\\\", model_ollama);\\n\",\n",
    "# llm_judge = getChatLLM(\\\"ollama\\\", model_ollama)\\n\",\n",
    "# llm_generate = getChatLLM(\\\"ollama\\\", model_ollama)\"\n",
    "\n",
    "# replicate part\n",
    "config = dotenv_values(\"./.env\")\n",
    "# You can use different llm models\n",
    "model_replicate_id1 = \"meta/meta-llama-3-70b-instruct\"\n",
    "llm_plan = getChatLLM(\"replicate\", model_replicate_id1, config)\n",
    "llm_judge = getChatLLM(\"replicate\", model_replicate_id1, config)\n",
    "llm_generate = getChatLLM(\"replicate\", model_replicate_id1, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.graph import StateGraph, END\n",
    "from llm_utils.agent_helpers import *\n",
    "from llm_utils.prompts.planner_prompt import *\n",
    "from llm_utils.prompts.judge_prompt import *\n",
    "from llm_utils.prompts.generate_prompt import *\n",
    "from llm_utils.dpk.tools import *\n",
    "from llm_utils.dpk.examples import *\n",
    "from llm_utils.dpk.constraints import *\n",
    "from functools import partial\n",
    "\n",
    "\n",
    "# Create the graph\n",
    "workflow = StateGraph(State)\n",
    "\n",
    "# Add nodes\n",
    "workflow.add_node(\"planner\", partial(planner, prompt=planner_prompt_str, tools=tools_json, example=example_task1, context=constraints, llm=llm_plan))\n",
    "workflow.add_node(\"judge\", partial(judge, prompt=judge_prompt_str_dpk, tools=tools_json, context=constraints, llm=llm_judge))\n",
    "workflow.add_node(\"user_review\", get_user_review)\n",
    "workflow.add_node(\"code generator\", partial(generator, prompt=generate_prompt_str_with_example, llm=llm_generate))\n",
    "workflow.add_node(\"code validator\", code_validator_noop)\n",
    "\n",
    "# Add edges\n",
    "workflow.set_entry_point(\"planner\")\n",
    "workflow.add_edge(\"code generator\", \"code validator\")\n",
    "workflow.add_edge(\"code validator\", END)\n",
    "\n",
    "# Add conditional edges from judge\n",
    "workflow.add_conditional_edges(\n",
    "    \"judge\",\n",
    "    is_plan_OK,\n",
    "    {\n",
    "        False: \"planner\",  # If needs revision, go back to planner\n",
    "        True: \"user_review\"  # If plan is good, proceed to user review\n",
    "    }\n",
    ")\n",
    "\n",
    "# Add conditional edges from planner\n",
    "workflow.add_conditional_edges(\n",
    "    \"planner\",\n",
    "    need_judge,\n",
    "    {\n",
    "        True: \"judge\",  # If needs revision, go back to planner\n",
    "        False: \"user_review\"  # If plan is good, proceed to user review\n",
    "    }\n",
    ")\n",
    "\n",
    "workflow.add_conditional_edges(\n",
    "    \"user_review\",\n",
    "    is_user_review_OK,\n",
    "    {\n",
    "        False: \"planner\",  # If needs revision, go back to planner\n",
    "        True: \"code generator\",\n",
    "    }\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "from IPython.display import Image, display\n",
    "\n",
    "app = workflow.compile()\n",
    "mermaid_syntax = app.get_graph().draw_mermaid()\n",
    "response = requests.post(\"https://kroki.io/mermaid/png\", \n",
    "                        data=mermaid_syntax, \n",
    "                        headers={'Content-Type': 'text/plain'})\n",
    "\n",
    "display(Image(response.content))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the graph\n",
    "initial_state = {\n",
    "    \"task\": task,\n",
    "    \"context\": \"\",\n",
    "    \"plan\": [\"still no plan\"],\n",
    "    \"planning_attempts\": 0,\n",
    "    \"feedback\": \"Still no review\",\n",
    "    \"needs_revision\": \"\",\n",
    "    \"need_judge\": True,\n",
    "}\n",
    "\n",
    "state = initial_state\n",
    "\n",
    "for output in app.stream(state):\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
